{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export to ONNX and inference using TensorRT\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Note:</b>\n",
    "\n",
    "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n",
    "\n",
    "</div>\n",
    "\n",
    "Transformer Engine (TE) is a library designed primarily for training DL models in low precision. It is not specifically optimized for inference tasks, so other dedicated solutions should be used. NVIDIA provides several [inference tools](https://www.nvidia.com/en-us/solutions/ai/inference/) that enhance the entire inference pipeline. Two prominent NVIDIA inference SDKs are [TensorRT](https://github.com/NVIDIA/TensorRT) and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).\n",
    "\n",
    "This tutorial illustrates how one can export a PyTorch model to ONNX format and subsequently perform inference with TensorRT. This approach is particularly beneficial if model integrates Transformer Engine layers within more complex architectures. It's important to highlight that for Transformer-based large language models (LLMs), TensorRT-LLM could provide a more optimized inference experience. However, the ONNX-to-TensorRT approach described here may be more suitable for other models, such as diffusion-based architectures or vision transformers.\n",
    "\n",
    "#### Creating models with TE\n",
    "\n",
    "Let's begin by defining a simple model composed of layers both from Transformer Engine and standard PyTorch:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch \n",
    "import torch.nn as nn\n",
    "import transformer_engine as te\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# batch size, sequence length, hidden dimension\n",
    "B, S, H = 256, 512, 256\n",
    "\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self, hidden_dim=H, num_non_te_layers=16, num_te_layers=4, num_te_heads=4):\n",
    "        super(Model, self).__init__()\n",
    "        self.non_te_part = nn.Sequential(\n",
    "            *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_non_te_layers)]\n",
    "        )\n",
    "        self.te_part = nn.Sequential(\n",
    "            *[te.pytorch.TransformerLayer(hidden_dim, hidden_dim, num_te_heads) for _ in range(num_te_layers)]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.non_te_part(x)\n",
    "        return self.te_part(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run some simple inference benchmarks:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average inference time FP32: 0.065 ms\n",
      "Average inference time FP8: 0.062 ms\n"
     ]
    }
   ],
   "source": [
    "from utils import  _measure_time\n",
    "\n",
    "model = Model().eval().cuda()\n",
    "inps = (torch.randn([S, B, H], device=\"cuda\"),)\n",
    "def _inference(fp8_enabled):\n",
    "     with torch.no_grad(), te.pytorch.autocast(enabled=fp8_enabled):\n",
    "        model(*inps)\n",
    "\n",
    "te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n",
    "te_fp8_time = _measure_time(lambda: _inference(fp8_enabled=True))\n",
    "\n",
    "print(f\"Average inference time FP32: {te_fp32_time} ms\")\n",
    "print(f\"Average inference time FP8: {te_fp8_time} ms\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Exporting the TE Model to ONNX Format\n",
    "\n",
    "PyTorch developed a new [ONNX exporter](https://pytorch.org/docs/stable/onnx.html) built on TorchDynamo and plans to phase out the existing TorchScript exporter. As this feature is currently in active development, we recommend running this process with the latest PyTorch version.\n",
    "\n",
    "\n",
    "To export a Transformer Engine model into ONNX format, follow these steps:\n",
    "\n",
    "- Conduct warm-up run within autocast using the recipe intended for export.\n",
    "- Encapsulate your export-related code within `te.onnx_export`, ensuring warm-up runs remain outside this wrapper.\n",
    "- Use the PyTorch Dynamo ONNX exporter by invoking: `torch.onnx.export(..., dynamo=True)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Exporting model_fp8.onnx\n",
      "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n",
      "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n",
      "[torch.onnx] Run decomposition...\n",
      "[torch.onnx] Run decomposition... ✅\n",
      "[torch.onnx] Translate the graph into ONNX...\n",
      "[torch.onnx] Translate the graph into ONNX... ✅\n",
      "Applied 12 of general pattern rewrite rules.\n",
      "Exporting model_fp32.onnx\n",
      "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n",
      "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n",
      "[torch.onnx] Run decomposition...\n",
      "[torch.onnx] Run decomposition... ✅\n",
      "[torch.onnx] Translate the graph into ONNX...\n",
      "[torch.onnx] Translate the graph into ONNX... ✅\n",
      "Applied 12 of general pattern rewrite rules.\n"
     ]
    }
   ],
   "source": [
    "from transformer_engine.pytorch.export import te_translation_table\n",
    "\n",
    "def export(model, fname, inputs, fp8=True):\n",
    "    with torch.no_grad(), te.pytorch.autocast(enabled=fp8):\n",
    "        # ! IMPORTANT !\n",
    "        # Transformer Engine models must have warm-up run\n",
    "        # before export. FP8 recipe during warm-up should  \n",
    "        # match the recipe used during export.\n",
    "        model(*inputs)\n",
    "    \n",
    "        # Only dynamo=True mode is supported;\n",
    "        # dynamo=False is deprecated and unsupported.\n",
    "        #\n",
    "        # te_translation_table contains necessary ONNX translations\n",
    "        # for FP8 quantize/dequantize operators.\n",
    "        print(f\"Exporting {fname}\")\n",
    "        with te.pytorch.onnx_export(enabled=True):\n",
    "            torch.onnx.export(\n",
    "                model,\n",
    "                inputs,\n",
    "                fname,\n",
    "                output_names=[\"output\"],\n",
    "                dynamo=True,\n",
    "                custom_translation_table=te_translation_table\n",
    "            )\n",
    "\n",
    "# Example usage:\n",
    "export(model, \"model_fp8.onnx\", inps, fp8=True)\n",
    "export(model, \"model_fp32.onnx\", inps, fp8=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Inference with TensorRT\n",
    "\n",
    "TensorRT is a high-performance deep learning inference optimizer and runtime developed by NVIDIA. It enables optimized deployment of neural network models by maximizing inference throughput and reducing latency on NVIDIA GPUs. TensorRT performs various optimization techniques, including layer fusion, precision calibration, kernel tuning, and memory optimization. \n",
    "For detailed information and documentation, refer to the official [TensorRT documentation](https://developer.nvidia.com/tensorrt).\n",
    "\n",
    "When using TensorRT, ONNX model must first be compiled into a TensorRT engine. This compilation step involves converting the ONNX model into an optimized representation tailored specifically to the target GPU platform. The compiled engine file can then be loaded into applications for rapid and efficient inference execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "!trtexec --onnx=model_fp32.onnx --saveEngine=model_fp32.engine > output_fp32.log 2>&1\n",
    "!trtexec --onnx=model_fp8.onnx --saveEngine=model_fp8.engine > output_fp8.log 2>&1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's run the benchmarks for inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average inference time without TRT (FP32 for all layers):                       0.065 ms\n",
      "Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): 0.062 ms, speedup = 1.05x\n",
      "Average inference time with TRT (FP32 for all layers):                          0.0500 ms, speedup = 1.30x\n",
      "Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers):    0.0470 ms, speedup = 1.38x\n"
     ]
    }
   ],
   "source": [
    "import tensorrt as trt\n",
    "\n",
    "# Output tensor is allocated - TRT needs static memory address.\n",
    "output_tensor = torch.empty_like(model(*inps))\n",
    "\n",
    "# Loads TRT engine from file.\n",
    "def load_engine(engine_file_path):\n",
    "    logger = trt.Logger(trt.Logger.WARNING)\n",
    "    runtime = trt.Runtime(logger)\n",
    "    \n",
    "    with open(engine_file_path, \"rb\") as f:\n",
    "        engine_data = f.read()\n",
    "    \n",
    "    engine = runtime.deserialize_cuda_engine(engine_data)\n",
    "    return engine\n",
    "\n",
    "def benchmark_inference(model_name):\n",
    "    engine = load_engine(model_name)\n",
    "    context = engine.create_execution_context()\n",
    "    stream = torch.cuda.Stream()\n",
    "    \n",
    "    # TRT need static input and output addresses.\n",
    "    # Here they are set.\n",
    "    for i in range(len(inps)):\n",
    "        context.set_tensor_address(engine.get_tensor_name(i), inps[i].data_ptr())    \n",
    "    context.set_tensor_address(\"output\", output_tensor.data_ptr())\n",
    "    \n",
    "    def _inference():\n",
    "        # The data is loaded from static input addresses\n",
    "        # and output is written to static output address.\n",
    "        context.execute_async_v3(stream_handle=stream.cuda_stream)\n",
    "        stream.synchronize()\n",
    "        \n",
    "    return _measure_time(_inference)\n",
    "\n",
    "\n",
    "trt_fp8_time = benchmark_inference(\"model_fp8.engine\")\n",
    "trt_fp32_time = benchmark_inference(\"model_fp32.engine\")\n",
    "\n",
    "print(f\"Average inference time without TRT (FP32 for all layers):                       {te_fp32_time} ms\")\n",
    "print(f\"Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): {te_fp8_time} ms, speedup = {te_fp32_time/te_fp8_time:.2f}x\")\n",
    "print(f\"Average inference time with TRT (FP32 for all layers):                          {trt_fp32_time:.4f} ms, speedup = {te_fp32_time/trt_fp32_time:.2f}x\")\n",
    "print(f\"Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers):    {trt_fp8_time:.4f} ms, speedup = {te_fp32_time/trt_fp8_time:.2f}x\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p>\n",
    "\n",
    "\n",
    "| Run                               | Inference Time (ms) | Speedup             |\n",
    "| ----------------------------------| ------------------- | ------------------- |\n",
    "| PyTorch + TE                      | 0.065               | 1.00x               |\n",
    "| PyTorch + TE (FP8 for TE layers)  | 0.062               | 1.05x               |\n",
    "| TRT                               | 0.0500              | 1.30x               |\n",
    "| TRT (FP8 for TE layers)           | 0.047               | 1.38x               |\n",
    "\n",
    "Note that this example highlights how TensorRT can speed up models composed of both TE and non-TE layers.\n",
    "If a larger part of the model's layers were implemented with TE, the benefits of using FP8 for inference could be greater.\n",
    "\n",
    "</p>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We clearly observe performance improvements when using FP8 and the TensorRT inference engine. These improvements may become even more significant with more complex models, as TensorRT could potentially identify additional optimization opportunities.\n",
    "\n",
    "#### Appendix: Low Precision Operators in ONNX and TensorRT\n",
    "\n",
    "The ONNX standard does not currently support all precision types provided by the Transformer Engine. All available ONNX operators are listed on [this website](https://onnx.ai/onnx/operators/). Consequently, TensorRT and the Transformer Engine utilize certain specialized low-precision operators, detailed below.\n",
    "\n",
    "**TRT_FP8_QUANTIZE**\n",
    "\n",
    "- **Name**: TRT_FP8_QUANTIZE\n",
    "- **Domain**: trt\n",
    "- **Inputs**:\n",
    "    - `x`: float32 tensor\n",
    "    - `scale`: float32 scalar\n",
    "- **Outputs**:\n",
    "    - `y`: int8 tensor\n",
    "\n",
    "Produces an int8 tensor that represents the binary encoding of FP8 values.\n",
    "\n",
    "**TRT_FP8_DEQUANTIZE**\n",
    "\n",
    "- **Name**: TRT_FP8_DEQUANTIZE\n",
    "- **Domain**: trt\n",
    "- **Inputs**:\n",
    "    - `x`: int8 tensor\n",
    "    - `scale`: float32 scalar\n",
    "- **Outputs**:\n",
    "    - `y`: float32 tensor\n",
    "\n",
    "Converts FP8-encoded int8 tensor data back into float32 precision.\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Note:</b>\n",
    "\n",
    "Since standard ONNX operators do not support certain input and output precision types, a workaround is employed: tensors are dequantized to higher precision (float32) before input into these operators or quantized to lower precision after processing. TensorRT recognizes such quantize-dequantize patterns and replaces them with optimized operations. More details are available in [this section](https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-quantized-types.html#tensorrt-processing-of-q-dq-networks) of the TensorRT documentation.\n",
    "\n",
    "</div>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
